"""
Phase 6: Robustness Checks
============================
Subsamples, alternative measures, weak instrument diagnostics, OLS vs. IV comparison.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output-table directory, both relative to the
# project root; the output directory is created up front (no error if it
# already exists).
DATA = PROJECT_DIR.joinpath("data", "processed")
OUT_TABLES = PROJECT_DIR.joinpath("output", "tables")
OUT_TABLES.mkdir(exist_ok=True, parents=True)


# The 38 OECD member countries as ISO 3166-1 alpha-3 codes; used for the
# OECD / non-OECD sample split and the OECD-only first stage.
OECD_COUNTRIES = set("""
    AUS AUT BEL CAN CHE CHL COL CRI CZE DEU
    DNK ESP EST FIN FRA GBR GRC HUN IRL ISL
    ISR ITA JPN KOR LTU LUX LVA MEX NLD NOR
    NZL POL PRT SVK SVN SWE TUR USA
""".split())
# Jurisdictions treated as international financial centers; the robustness
# checks re-estimate the baseline without them.
FINANCIAL_CENTERS = set("LUX IRL HKG SGP NLD CHE".split())


def stars(p):
    """Return the significance marker for p-value *p*.

    '***' when p < 0.01, '**' when p < 0.05, '*' when p < 0.1, otherwise ''
    (including for NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, y_var, x_vars, label, *, min_obs=30):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars`` and summarize it.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with ``y_var``, the regressors, and ``iso3`` / ``year`` columns
        (passed to ``PanelGLS.fit`` as the panel identifiers).
    y_var : str
        Dependent-variable column.
    x_vars : list of str
        Regressor columns, in coefficient order. Columns absent from ``df``
        are dropped from the specification and reported in the log.
    label : str
        Specification name, stored under ``'model'`` and used in messages.
    min_obs : int, optional
        Minimum number of complete rows required to estimate (default 30).

    Returns
    -------
    dict or None
        ``model``, ``y_var``, ``n_obs``, ``n_countries``, ``r_squared`` plus
        ``<name>_coef`` / ``<name>_se`` / ``<name>_p`` for every regressor
        actually used; ``None`` (with a message) when ``y_var`` is missing,
        no regressor is available, or fewer than ``min_obs`` rows remain.
    """
    if y_var not in df.columns:
        print(f"  {label}: {y_var} not in panel, skipping")
        return None

    # Dropping a regressor changes the specification being estimated, so
    # report it instead of doing it silently.
    actual_x = [v for v in x_vars if v in df.columns]
    missing_x = [v for v in x_vars if v not in df.columns]
    if missing_x:
        print(f"  {label}: regressors not in panel, dropped: {', '.join(missing_x)}")
    if not actual_x:
        print(f"  {label}: no regressors available, skipping")
        return None

    # Listwise deletion over exactly the columns the regression uses.
    sub = df[[y_var] + actual_x + ['iso3', 'year']].dropna()
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[actual_x].values,
            sub['iso3'].values, sub['year'].values)

    print(f"  {label}: N={gls.n_obs}, R²={gls.r_squared:.4f}")
    row = {'model': label, 'y_var': y_var, 'n_obs': gls.n_obs,
           'n_countries': gls.n_countries, 'r_squared': gls.r_squared}
    # gls.beta / se / pvalues are indexed in the same order as actual_x.
    for i, name in enumerate(actual_x):
        row[f'{name}_coef'] = gls.beta[i]
        row[f'{name}_se'] = gls.se[i]
        row[f'{name}_p'] = gls.pvalues[i]
    return row


def run_first_stage(df, endog_var, instrument_var, controls, label, *, min_obs=50):
    """Estimate the first stage of a just-identified IV model and report its strength.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel with the endogenous variable, the instrument, the controls and
        ``iso3`` / ``year`` columns.
    endog_var : str
        Endogenous regressor (first-stage dependent variable).
    instrument_var : str
        The single excluded instrument; it is placed first in the regressor
        list, so its estimate is element 0 of the coefficient vector.
    controls : list of str
        Included exogenous controls; those absent from ``df`` are skipped.
    label : str
        Sample name stored under ``'model'``.
    min_obs : int, optional
        Minimum number of complete rows required (default 50).

    Returns
    -------
    dict or None
        ``first_stage_F`` (squared t-statistic of the instrument, i.e. the
        Wald F for one restriction using PanelGLS's standard errors), the
        instrument's coefficient / SE / p-value, ``first_stage_R2`` (R² of
        the full first stage) and ``partial_R2`` (share of the variation
        left after the controls that the instrument explains; NaN when there
        are no controls). ``None`` (with a message) when a required column
        is missing or fewer than ``min_obs`` rows remain.
    """
    if endog_var not in df.columns or instrument_var not in df.columns:
        print(f"  {label}: {endog_var} or {instrument_var} not in panel, skipping")
        return None

    actual_controls = [c for c in controls if c in df.columns]
    sub = df[[endog_var, instrument_var] + actual_controls + ['iso3', 'year']].dropna()
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[endog_var].values
    groups = sub['iso3'].values
    years = sub['year'].values

    gls = PanelGLS()
    gls.fit(y, sub[[instrument_var] + actual_controls].values, groups, years)

    # With a single excluded instrument the first-stage F equals t².
    f_stat = (gls.beta[0] / gls.se[0]) ** 2

    # Partial R² of the instrument: compare against the same regression
    # without it on the same sample:
    #   (R²_full − R²_controls) / (1 − R²_controls).
    # NOTE(review): assumes PanelGLS.r_squared is measured against the same
    # total sum of squares in both fits (same y, same sample) — confirm for
    # the GLS weighting used by PanelGLS.
    partial_r2 = np.nan
    if actual_controls:
        restricted = PanelGLS()
        restricted.fit(y, sub[actual_controls].values, groups, years)
        r2_controls = restricted.r_squared
        if r2_controls < 1:
            partial_r2 = (gls.r_squared - r2_controls) / (1 - r2_controls)

    return {
        'model': label,
        'n_obs': gls.n_obs,
        'first_stage_F': f_stat,
        'first_stage_coef': gls.beta[0],
        'first_stage_se': gls.se[0],
        'first_stage_p': gls.pvalues[0],
        'partial_R2': partial_r2,
        'first_stage_R2': gls.r_squared,
    }


def _format_cell(value, spec):
    """Format *value* with format spec *spec*, or return '' when it is missing/NaN."""
    return format(value, spec) if pd.notna(value) else ''


def _flow_coef_cells(models, y, controls):
    """Return ``(beta, p)`` strings for the flow regressor of the first model with outcome *y*.

    The flow regressor is identified heuristically as the first ``*_coef``
    column whose name contains ``'log_'``, that is not one of *controls*
    (``log_rel_opw`` would otherwise match), and that is non-missing for this
    row — Phase 4 results stack models with different regressors, so most
    coefficient columns are NaN for any given row. Returns ``('', '')`` when
    nothing matches.
    """
    rows = models[models['y_var'] == y]
    if len(rows) == 0:
        return '', ''
    row = rows.iloc[0]
    control_cols = {f'{c}_coef' for c in controls}
    for col in row.index:
        if (col.endswith('_coef') and 'log_' in col
                and col not in control_cols and pd.notna(row[col])):
            p_value = row.get(col[:-len('_coef')] + '_p')
            return f"{row[col]:.4f}", _format_cell(p_value, '.4f')
    return '', ''


def _run_specifications(df, y_var, x_demo, controls):
    """Sections 1-7: subsample, alternative-outcome and reduced-form regressions.

    Returns the list of ``run_model`` result dicts that could be estimated.
    """
    regressors = x_demo + controls
    results = []

    def add(sample, dep_var, x_vars, label):
        r = run_model(sample, dep_var, x_vars, label)
        if r:
            results.append(r)

    print("\n--- 1. Japan Exclusion ---")
    add(df[df['iso3'] != 'JPN'], y_var, regressors, 'Excl. Japan')

    print("\n--- 2. OECD vs. Non-OECD ---")
    in_oecd = df['iso3'].isin(OECD_COUNTRIES)
    add(df[in_oecd], y_var, regressors, 'OECD Only')
    add(df[~in_oecd], y_var, regressors, 'Non-OECD')

    print("\n--- 3. Pre/Post GFC ---")
    # 2008-2009 fall in neither window: the crisis years themselves are excluded.
    add(df[df['year'] < 2008], y_var, regressors, 'Pre-GFC (<2008)')
    add(df[df['year'] >= 2010], y_var, regressors, 'Post-GFC (≥2010)')

    print("\n--- 4. Alternative TFP ---")
    for tfp_var, tfp_label in [('delta_log_tfp', 'TFP (ctfp)'),
                               ('delta_log_rtfpna', 'TFP (rtfpna)')]:
        if tfp_var in df.columns:
            add(df, tfp_var, regressors, f'Z → {tfp_label}')

    print("\n--- 5. Net vs. Gross Flows ---")
    for flow_var, flow_label in [('ca_gdp', 'Net (CA/GDP)'),
                                 ('gross_liab_gdp', 'Gross Liabilities/GDP')]:
        if flow_var in df.columns:
            add(df, flow_var, regressors, f'Z → {flow_label}')

    print("\n--- 6. Exclude Financial Centers ---")
    add(df[~df['iso3'].isin(FINANCIAL_CENTERS)], y_var, regressors,
        'Excl. Financial Centers')

    print("\n--- 7. Alternative Instrument Specifications ---")
    # Reduced forms: the demographics-only gravity instrument (no GDP in the
    # instrument) and the full gravity instrument, each regressed directly on K/L.
    for inst_var, rf_label in [
            ('log_predicted_demo_inflows', 'Reduced Form: Demo-only Gravity → K/L'),
            ('log_predicted_total_inflows', 'Reduced Form: Full Gravity → K/L')]:
        if inst_var in df.columns:
            add(df, y_var, [inst_var] + controls, rf_label)

    return results


def _weak_instrument_diagnostics(df, controls):
    """Section 8: first-stage strength of the demographic gravity instrument.

    Returns the ``run_first_stage`` dicts for the full and OECD samples.
    """
    print("\n--- 8. Weak Instrument Diagnostics ---")
    endog = 'log_total_portfolio_inflows'
    instrument = 'log_predicted_demo_inflows'
    if instrument not in df.columns or endog not in df.columns:
        return []

    diagnostics = []
    fs = run_first_stage(df, endog, instrument, controls, 'Full Sample')
    if fs:
        diagnostics.append(fs)
        print(f"  First-stage F = {fs['first_stage_F']:.2f} "
              f"(>10 threshold: {'PASS' if fs['first_stage_F'] > 10 else 'FAIL'})")

    fs_oecd = run_first_stage(df[df['iso3'].isin(OECD_COUNTRIES)], endog,
                              instrument, controls, 'OECD First Stage')
    if fs_oecd:
        diagnostics.append(fs_oecd)
        print(f"  OECD First-stage F = {fs_oecd['first_stage_F']:.2f}")

    return diagnostics


def _write_ols_iv_comparison(controls):
    """Section 9: tabulate OLS, IV and reduced-form flow effects from Phase 4.

    Reads ``flow_outcomes_results.csv`` from ``OUT_TABLES`` and writes
    ``ols_iv_comparison.md``; does nothing if the Phase 4 file is absent.
    """
    p4_path = OUT_TABLES / "flow_outcomes_results.csv"
    if not p4_path.exists():
        return
    p4 = pd.read_csv(p4_path)
    print(f"  Phase 4 results loaded: {len(p4)} models")

    names = p4['model']
    ols_models = p4[names.str.contains('OLS', na=False)]
    iv_models = p4[names.str.contains('IV:', na=False)]
    rf_models = p4[names.str.contains('Reduced', na=False)]

    lines = [
        "# OLS vs. IV Comparison\n\n",
        "| Outcome | OLS β | OLS p | IV β | IV p | 1st-F | RF β | RF p |\n",
        "|---------|-------|-------|------|------|-------|------|------|\n",
    ]
    for y in ['delta_log_kl', 'gross_fixed_investment_gdp', 'delta_log_tfp',
              'rgdp_growth', 'mpk_proxy']:
        ols_b, ols_p = _flow_coef_cells(ols_models, y, controls)
        rf_b, rf_p = _flow_coef_cells(rf_models, y, controls)
        iv_b = iv_p = iv_f = ''
        iv_rows = iv_models[iv_models['y_var'] == y]
        if len(iv_rows) > 0:
            row = iv_rows.iloc[0]
            iv_b = _format_cell(row.get('iv_coef'), '.4f')
            iv_p = _format_cell(row.get('iv_p_bootstrap'), '.4f')
            iv_f = _format_cell(row.get('first_stage_F'), '.1f')
        lines.append(f"| {y} | {ols_b} | {ols_p} | {iv_b} | {iv_p} | {iv_f} | {rf_b} | {rf_p} |\n")
    lines.append("\n")

    # Explicit UTF-8: the table contains β, which locale encodings such as
    # cp1252 cannot encode.
    with open(OUT_TABLES / "ols_iv_comparison.md", 'w', encoding='utf-8') as f:
        f.writelines(lines)


def _write_robustness_table(results):
    """Write ``robustness.md``: Z coefficients per specification plus first-stage diagnostics."""
    def z_cells(r, z):
        # run_model omits a Z regressor that is absent from the panel; show a
        # dash instead of a fabricated 0.0000 estimate with p = 1.
        if f'{z}_coef' not in r:
            return '—', '—'
        p = r[f'{z}_p']
        return f"{r[f'{z}_coef']:.4f}{stars(p)}", f"{p:.4f}"

    lines = [
        "# Table 6: Robustness Checks\n\n",
        "## Demographic Coefficients by Specification\n\n",
        "Baseline dependent variable: Δlog(K/L) — capital deepening; "
        "alternative outcomes are listed in the Dep. variable column.\n\n",
        "| Specification | Dep. variable | N | R² | Z₁ β | Z₁ p | Z₂ β | Z₂ p | Z₃ β | Z₃ p |\n",
        "|---------------|---------------|---|-----|------|------|------|------|------|------|\n",
    ]
    for r in results:
        if 'Z_1_coef' not in r:
            continue
        cells = [cell for z in ('Z_1', 'Z_2', 'Z_3') for cell in z_cells(r, z)]
        lines.append(f"| {r['model']} | {r['y_var']} | {r['n_obs']} | {r['r_squared']:.4f} | "
                     + " | ".join(cells) + " |\n")

    lines += [
        "\n## Weak Instrument Diagnostics\n\n",
        "| Sample | N | First-stage F | Instrument β | p |\n",
        "|--------|---|---------------|-------------|---|\n",
    ]
    for r in results:
        if 'first_stage_F' in r:
            lines.append(f"| {r['model']} | {r['n_obs']} | {r['first_stage_F']:.2f} "
                         f"| {r['first_stage_coef']:.4f} | {r['first_stage_p']:.4f} |\n")

    # 16.38 is the Stock-Yogo (2005) critical value for 10% maximal TSLS size
    # with one endogenous regressor and one instrument; bias-based critical
    # values require at least three instruments.
    lines.append("\n*Stock-Yogo critical value for 10% maximal IV size: F > 16.38 "
                 "(one endogenous regressor, one instrument).*\n")
    lines.append("*p<0.1, **p<0.05, ***p<0.01\n")

    # Explicit UTF-8: the table contains Δ, β, ² and subscript digits.
    with open(OUT_TABLES / "robustness.md", 'w', encoding='utf-8') as f:
        f.writelines(lines)


def main():
    """Run every Phase 6 robustness check and write the output tables.

    Reads ``deepening_panel.csv`` from ``DATA``; writes
    ``robustness_results.csv`` and ``robustness.md`` (and, when the Phase 4
    results file exists, ``ols_iv_comparison.md``) into ``OUT_TABLES``.
    Returns early if the baseline outcome ``delta_log_kl`` is missing.
    """
    print("=" * 70)
    print("PHASE 6: ROBUSTNESS CHECKS")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']
    y_var = 'delta_log_kl'
    x_demo = ['Z_1', 'Z_2', 'Z_3']

    if y_var not in df.columns:
        print(f"ERROR: {y_var} not in panel. Check Phase 1 output.")
        return

    results = _run_specifications(df, y_var, x_demo, controls)
    results += _weak_instrument_diagnostics(df, controls)

    print("\n--- 9. OLS vs. IV Comparison ---")
    _write_ols_iv_comparison(controls)

    # Save all robustness results (regression rows and first-stage rows).
    pd.DataFrame(results).to_csv(OUT_TABLES / "robustness_results.csv", index=False)
    _write_robustness_table(results)

    print("\nPhase 6 complete.")


if __name__ == '__main__':
    main()
